package mybatis.mapper;

import mybatis.annotation.Param;
import mybatis.config.Configuration;
import mybatis.session.SqlSession;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

/**
 * Invocation handler behind mapper-interface proxies: turns a call on a mapper
 * method into the matching {@link SqlSession} operation.
 *
 * <p>The statement id is built as {@code <mapper interface name>.<method name>}
 * and looked up in the {@link Configuration} exposed by the session.
 *
 * @author cg
 * @date 2023/6/20 16:03
 */
public class MapperHandler implements InvocationHandler {
    private final SqlSession sqlSession;

    private final Class<?> mapperInterface;

    /**
     * @param sqlSession      session used to execute the resolved statements
     * @param mapperInterface mapper interface this handler serves; its name is
     *                        the prefix of every statement id
     */
    public MapperHandler(SqlSession sqlSession, Class<?> mapperInterface) {
        this.sqlSession = sqlSession;
        this.mapperInterface = mapperInterface;
    }

    /**
     * Handles a call made on the mapper proxy.
     *
     * <p>{@code equals}, {@code hashCode} and {@code toString} are answered
     * locally; every other method is resolved to a mapped statement and
     * executed through the session.
     *
     * @param proxy  the proxy instance the method was invoked on
     * @param method the invoked interface method
     * @param args   the call arguments, {@code null} when the method takes none
     * @return the result of the executed statement, or the local answer for
     *         {@code Object} methods
     * @throws IllegalStateException if no statement is registered for the method
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // A JDK proxy routes Object's equals/hashCode/toString through the handler;
        // they have no mapped statement and must not reach the SQL layer.
        if (Object.class.equals(method.getDeclaringClass())) {
            return invokeObjectMethod(proxy, method, args);
        }
        // Resolve the call arguments into named statement parameters.
        Map<String, Object> parameters = resolveParameters(method, args);
        // Pick and run the session operation that matches the statement type.
        return getResult(method, parameters);
    }

    /**
     * Answers the {@code Object} methods a proxy forwards to its handler, using
     * identity semantics for {@code equals}/{@code hashCode}.
     */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return this.mapperInterface.getName() + "@"
                        + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new IllegalStateException("Unexpected Object method on mapper proxy: " + method);
        }
    }

    /**
     * Looks up the statement for {@code method} and executes it with the given
     * parameters.
     *
     * @return for SELECT, the list result when the method's return type is a
     *         {@link Collection}, otherwise the single result; for UPDATE and
     *         INSERT, whatever {@code sqlSession.update} returns; {@code null}
     *         for any other command type
     * @throws IllegalStateException if the configuration has no statement for
     *                               the computed id
     */
    private Object getResult(Method method, Map<String, Object> parameters) {
        Object result = null;
        String statementId = this.mapperInterface.getName() + "." + method.getName();
        Configuration configuration = this.sqlSession.getSqlConfig();
        MapperStatement mapperStatement = configuration.getMapperStatement(statementId);
        if (mapperStatement == null) {
            // Fail with the offending id instead of an anonymous NullPointerException.
            throw new IllegalStateException("No mapped statement found for id: " + statementId);
        }
        // Kind of SQL operation declared for this statement.
        SqlType type = mapperStatement.getSqlCommandType();
        switch (type) {
            // A query returns many rows or one, depending on the method's return type.
            case SELECT: {
                Class<?> returnType = method.getReturnType();
                if (Collection.class.isAssignableFrom(returnType)) {
                    // Collection return type: fetch all matching rows.
                    result = sqlSession.selectList(statementId, parameters);
                } else {
                    result = sqlSession.selectOne(statementId, parameters);
                }
                break;
            }
            case UPDATE:
            case INSERT: {
                result = sqlSession.update(statementId, parameters);
                break;
            }
            default:
                // NOTE(review): any other command type is not executed and yields null;
                // confirm against SqlType whether e.g. a delete type needs handling here.
                break;
        }
        return result;
    }

    /**
     * Builds the named parameter map for a mapper call.
     *
     * <p>The parameter name is the {@link Param} value when the annotation is
     * present, otherwise the reflective parameter name. A {@code null} argument
     * or an argument whose class has no class loader (JDK types such as boxed
     * primitives and {@code String}) is stored as {@code name -> value}. Any
     * other argument is flattened: each field declared directly on its class is
     * stored as {@code name.field -> fieldValue}.
     *
     * @param method the invoked mapper method
     * @param args   the call arguments, {@code null} when the method takes none
     * @return a mutable map of parameter names to values
     */
    private Map<String, Object> resolveParameters(Method method, Object[] args) {
        Map<String, Object> parameters = new HashMap<>();
        if (args == null) {
            // Proxies pass null instead of an empty array for no-arg methods.
            return parameters;
        }
        int idx = 0;
        for (Parameter parameter : method.getParameters()) {
            String parameterName = parameter.getName();
            // An explicit @Param name wins over the reflective name.
            if (parameter.isAnnotationPresent(Param.class)) {
                parameterName = parameter.getAnnotation(Param.class).value();
            }
            Object value = args[idx++];
            if (value == null) {
                // Nothing to introspect: bind the null under the plain parameter name.
                parameters.put(parameterName, null);
                continue;
            }
            Class<?> clazz = value.getClass();
            if (clazz.getClassLoader() == null) {
                // No class loader means a JDK-provided type: bind the value as is.
                parameters.put(parameterName, value);
            } else {
                // User-defined type: expose each declared field as "<param>.<field>".
                for (Field declaredField : clazz.getDeclaredFields()) {
                    declaredField.setAccessible(true);
                    String fieldName = declaredField.getName();
                    try {
                        parameters.put(parameterName + "." + fieldName, declaredField.get(value));
                    } catch (IllegalAccessException e) {
                        throw new RuntimeException(e);
                    }
                }
            }
        }
        return parameters;
    }
}
